{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Stacked Bidirectional LSTM Sentiment Classifier"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook, we *stack* bidirectional LSTM layers to classify IMDB movie reviews by their sentiment."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/the-deep-learners/deep-learning-illustrated/blob/master/notebooks/stacked_bi_lstm_sentiment_classifier.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import keras\n",
    "from keras.datasets import imdb\n",
    "from keras.preprocessing.sequence import pad_sequences\n",
    "from keras.models import Sequential\n",
    "from keras.layers import Dense, Dropout, Embedding, SpatialDropout1D, LSTM\n",
    "from keras.layers.wrappers import Bidirectional \n",
    "from keras.callbacks import ModelCheckpoint\n",
    "import os\n",
    "from sklearn.metrics import roc_auc_score \n",
    "import matplotlib.pyplot as plt \n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Set hyperparameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# output directory name:\n",
    "output_dir = 'model_output/stackedLSTM'\n",
    "\n",
    "# training:\n",
    "epochs = 4\n",
    "batch_size = 128\n",
    "\n",
    "# vector-space embedding: \n",
    "n_dim = 64 \n",
    "n_unique_words = 10000 \n",
    "max_review_length = 200 \n",
    "pad_type = trunc_type = 'pre'\n",
    "drop_embed = 0.2 \n",
    "\n",
    "# LSTM layer architecture:\n",
    "n_lstm_1 = 64 # lower\n",
    "n_lstm_2 = 64 # new!\n",
    "drop_lstm = 0.2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# num_words caps the vocabulary at the n_unique_words most frequent tokens\n",
    "(x_train, y_train), (x_valid, y_valid) = imdb.load_data(\n",
    "    num_words=n_unique_words,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Preprocess data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pad/truncate both splits identically so every review is exactly max_review_length token ids long\n",
    "x_train, x_valid = (\n",
    "    pad_sequences(reviews, maxlen=max_review_length, padding=pad_type, truncating=trunc_type, value=0)\n",
    "    for reviews in (x_train, x_valid)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "#### Design neural network architecture"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Sequential([\n",
    "    # one n_dim-wide vector per token id, for sequences of max_review_length tokens\n",
    "    Embedding(n_unique_words, n_dim, input_length=max_review_length),\n",
    "    SpatialDropout1D(drop_embed),\n",
    "    # the first recurrent layer returns the whole sequence so the second one has timesteps to read\n",
    "    Bidirectional(LSTM(n_lstm_1, dropout=drop_lstm, return_sequences=True)),\n",
    "    Bidirectional(LSTM(n_lstm_2, dropout=drop_lstm)),\n",
    "    # single sigmoid unit: a score in (0, 1) for the binary sentiment label\n",
    "    Dense(1, activation='sigmoid'),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "embedding_1 (Embedding)      (None, 200, 64)           640000    \n",
      "_________________________________________________________________\n",
      "spatial_dropout1d_1 (Spatial (None, 200, 64)           0         \n",
      "_________________________________________________________________\n",
      "bidirectional_1 (Bidirection (None, 200, 128)          66048     \n",
      "_________________________________________________________________\n",
      "bidirectional_2 (Bidirection (None, 128)               98816     \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (None, 1)                 129       \n",
      "=================================================================\n",
      "Total params: 804,993\n",
      "Trainable params: 804,993\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# Each Bidirectional wrapper holds a forward and a backward LSTM, so it has twice the parameters of a single LSTM\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Configure model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.compile(\n",
    "    optimizer='adam',\n",
    "    loss='binary_crossentropy',  # pairs with the single sigmoid output unit\n",
    "    metrics=['accuracy'],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the checkpoint directory first; exist_ok=True makes re-running this cell a no-op\n",
    "# and avoids the race in a separate exists()-then-makedirs() check.\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "# Write one checkpoint file per epoch, with the zero-padded epoch number in its name.\n",
    "modelcheckpoint = ModelCheckpoint(filepath=output_dir+\"/weights.{epoch:02d}.hdf5\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Train!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 25000 samples, validate on 25000 samples\n",
      "Epoch 1/4\n",
      "25000/25000 [==============================] - 98s 4ms/step - loss: 0.4332 - acc: 0.7819 - val_loss: 0.3230 - val_acc: 0.8644\n",
      "Epoch 2/4\n",
      "25000/25000 [==============================] - 94s 4ms/step - loss: 0.2466 - acc: 0.9030 - val_loss: 0.2957 - val_acc: 0.8783\n",
      "Epoch 3/4\n",
      "25000/25000 [==============================] - 94s 4ms/step - loss: 0.1851 - acc: 0.9292 - val_loss: 0.3258 - val_acc: 0.8681\n",
      "Epoch 4/4\n",
      "25000/25000 [==============================] - 94s 4ms/step - loss: 0.1462 - acc: 0.9462 - val_loss: 0.3479 - val_acc: 0.8695\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x7f71df882198>"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit(\n",
    "    x_train, y_train,\n",
    "    batch_size=batch_size,\n",
    "    epochs=epochs,\n",
    "    verbose=1,\n",
    "    validation_data=(x_valid, y_valid),\n",
    "    callbacks=[modelcheckpoint],  # writes a checkpoint into output_dir each epoch\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "#### Evaluate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Epoch 2 had the lowest val_loss in the recorded training log; revisit this choice after retraining\n",
    "model.load_weights(\"{}/weights.02.hdf5\".format(output_dir))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# predict() returns the sigmoid outputs directly; predict_proba() was only a thin wrapper\n",
    "# around it and has been removed from later Keras/TensorFlow releases.\n",
    "y_hat = model.predict(x_valid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYAAAAD8CAYAAAB+UHOxAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAD49JREFUeJzt3W2sZVV9x/HvT0a0PoIyGjvQXoyjdTRpJBPEmlgrhkfD+AKasbWOZtJJLLXWmrbYvsCoJNgnrKliR6FFYx0oNWUitITyENvGGR3EUoESpkBhCpWrA9iW+DD674uzoBe8M3ffmXPP8bC+n+Tm7L322mev/9zL/d299j6bVBWSpP48ZdoDkCRNhwEgSZ0yACSpUwaAJHXKAJCkThkAktQpA0CSOmUASFKnDABJ6tSqaQ/gQI466qiam5ub9jCkH/Xt20evz3nZdMchLeLGG2/8ZlWtXqrfj3UAzM3NsWvXrmkPQ/pR//D60esbb5jmKKRFJfmPIf2cApKkThkAktQpA0CSOmUASFKnDABJ6pQBIEmdMgAkqVMGgCR1ygCQpE79WH8S+FDNnXPlVI579/mnT+W4krQcngFIUqcMAEnqlAEgSZ0yACSpUwaAJHXKAJCkThkAktQpA0CSOmUASFKnDABJ6pQBIEmdMgAkqVMGgCR1ygCQpE4ZAJLUKQNAkjplAEhSpwwASeqUASBJnTIAJKlTgwIgyXuS3JLk60k+l+TpSY5NsjPJHUkuTXJ46/u0tr67bZ9b8D7va+23Jzl5ZUqSJA2xZAAkWQP8BrC+ql4JHAZsBD4MXFBVa4EHgc1tl83Ag1X1EuCC1o8k69p+rwBOAT6e5LDxliNJGmrVMvr9RJLvA88A7gfeAPxS234J8H7gQmBDWwa4HPizJGnt26rqu8BdSXYDxwNfOvQyJGn85s65cmrHvvv801f8GEueAVTVfwJ/BNzD6Bf/w8CNwENVta912wOsactrgHvbvvta/+cvbF9kH0nShA2ZAjqS0V/vxwI/CTwTOHWRrvXoLvvZtr/2Jx5vS5JdSXbNz88vNTxJ0kEachH4jcBdVTVfVd8HPg/8HHBEkkenkI4G7mvLe4BjANr25wJ7F7Yvss9jqmprVa2vqvWrV68+iJIkSUMMCYB7gBOSPKPN5Z8I3ApcD5zZ+mwCrmjL29s6bft1VVWtfWO7S+hYYC3w5fGUIUlariUvAlfVziSXA18F9gE3AVuBK4FtST7U2i5qu1wEfKZd5N3L6M4fquqWJJcxCo99wNlV9YMx1yNJGmjQXUBVdS5w7hOa72R0F88T+34HOGs/73MecN4yxyhJWgF+EliSOmUASFKnDABJ6pQBIEmdMgAkqVMGgCR1ygCQpE4ZAJLUKQNAkjplAEhSpwwASeqUASBJnTIAJKlTBoAkdcoAkKROGQCS1CkDQJI6ZQBIUqcMAEnqlAEgSZ0yACSpUwaAJHXKAJCkThkAktQpA0CSOmUASFKnDABJ6pQBIEmdMgAkqVMGgCR1ygCQpE4ZAJLUKQNAkjplAEhSpwwASeqUASBJnTIAJKlTgwIgyRFJLk/yb0luS/KaJM9Lck2SO9rrka1vknw0ye4kNyc5bsH7bGr970iyaaWKkiQtbegZwJ8Cf19VPwP8LHAbcA5wbVWtBa5t6wCnAmvb1xbgQoAkzwPOBV4NHA+c+2hoSJImb8kASPIc4HXARQBV9b2qegjYAFzSul0CvLktbwA+XSM7gCOSvAg4GbimqvZW1YPANcApY61GkjTYkDOAFwPzwF8kuSnJp5I8E3hhVd0P0F5f0PqvAe5dsP+e1ra/9sdJsiXJriS75ufnl12QJGmYIQGwCjgOuLCqXgX8L/8/3bOYLNJWB2h/fEPV1qpaX1XrV69ePWB4kqSDMSQA9gB7qmpnW7+cUSB8o03t0F4fWND/mAX7Hw3cd4B2SdIULBkAVfVfwL1J
XtaaTgRuBbYDj97Jswm4oi1vB97W7gY6AXi4TRFdDZyU5Mh28fek1iZJmoJVA/u9C/hsksOBO4F3MAqPy5JsBu4Bzmp9rwJOA3YDj7S+VNXeJB8EvtL6faCq9o6lCknSsg0KgKr6GrB+kU0nLtK3gLP38z4XAxcvZ4CSpJXhJ4ElqVMGgCR1ygCQpE4ZAJLUKQNAkjplAEhSpwwASeqUASBJnTIAJKlTBoAkdcoAkKROGQCS1CkDQJI6ZQBIUqcMAEnqlAEgSZ0yACSpUwaAJHXKAJCkThkAktQpA0CSOmUASFKnDABJ6pQBIEmdMgAkqVMGgCR1ygCQpE4ZAJLUKQNAkjplAEhSpwwASeqUASBJnTIAJKlTBoAkdcoAkKROGQCS1KnBAZDksCQ3JflCWz82yc4kdyS5NMnhrf1pbX132z634D3e19pvT3LyuIuRJA23nDOAdwO3LVj/MHBBVa0FHgQ2t/bNwINV9RLggtaPJOuAjcArgFOAjyc57NCGL0k6WIMCIMnRwOnAp9p6gDcAl7culwBvbssb2jpt+4mt/wZgW1V9t6ruAnYDx4+jCEnS8g09A/gI8DvAD9v684GHqmpfW98DrGnLa4B7Adr2h1v/x9oX2UeSNGFLBkCSNwEPVNWNC5sX6VpLbDvQPguPtyXJriS75ufnlxqeJOkgDTkDeC1wRpK7gW2Mpn4+AhyRZFXrczRwX1veAxwD0LY/F9i7sH2RfR5TVVuran1VrV+9evWyC5IkDbNkAFTV+6rq6KqaY3QR97qq+mXgeuDM1m0TcEVb3t7Waduvq6pq7RvbXULHAmuBL4+tEknSsqxaust+/S6wLcmHgJuAi1r7RcBnkuxm9Jf/RoCquiXJZcCtwD7g7Kr6wSEcX5J0CJYVAFV1A3BDW76TRe7iqarvAGftZ//zgPOWO0hJ0vj5SWBJ6pQBIEmdMgAkqVMGgCR1ygCQpE4ZAJLUKQNAkjplAEhSpwwASeqUASBJnTIAJKlTBoAkdcoAkKROGQCS1CkDQJI6ZQBIUqcMAEnqlAEgSZ0yACSpUwaAJHXKAJCkThkAktQpA0CSOmUASFKnDABJ6pQBIEmdMgAkqVMGgCR1ygCQpE4ZAJLUKQNAkjplAEhSpwwASeqUASBJnTIAJKlTBoAkdcoAkKROLRkASY5Jcn2S25LckuTdrf15Sa5Jckd7PbK1J8lHk+xOcnOS4xa816bW/44km1auLEnSUoacAewD3ltVLwdOAM5Osg44B7i2qtYC17Z1gFOBte1rC3AhjAIDOBd4NXA8cO6joSFJmrwlA6Cq7q+qr7bl/wZuA9YAG4BLWrdLgDe35Q3Ap2tkB3BEkhcBJwPXVNXeqnoQuAY4ZazVSJIGW9Y1gCRzwKuAncALq+p+GIUE8ILWbQ1w74Ld9rS2/bU/8RhbkuxKsmt+fn45w5MkLcPgAEjyLOBvgN+sqm8fqOsibXWA9sc3VG2tqvVVtX716tVDhydJWqZBAZDkqYx++X+2qj7fmr/RpnZorw+09j3AMQt2Pxq47wDtkqQpGHIXUICLgNuq6k8WbNoOPHonzybgigXtb2t3A50APNymiK4GTkpyZLv4e1JrkyRNwaoBfV4L/Arwr0m+1tp+DzgfuCzJZuAe4Ky27SrgNGA38AjwDoCq2pvkg8BXWr8PVNXesVQhSVq2JQOgqv6JxefvAU5cpH8BZ+/nvS4GLl7OACVJK8NPAktSpwwASeqUASBJnRpyEViSpmrunCunPYQnJc8AJKlTBoAkdcoAkKROGQCS1CkDQJI6ZQBIUqcMAEnqlJ8DWAHTumf57vNPn8pxJc0mzwAkqVMGgCR1ygCQpE55DUDSYD6T58nFMwBJ6pQBIEmdMgAkqVNeA3gSmeb8rJ9BkGaPZwCS1CkDQJI65RSQxsLHX0yOt2JqXDwDkKROeQYgHaQdd36Ljf41rhlmAGimTWs6ZNuLvzWV40rj5BSQJHXKAJCkThkAktQpA0CS
OmUASFKnDABJ6pQBIEmdMgAkqVMGgCR1ygCQpE5NPACSnJLk9iS7k5wz6eNLkkYmGgBJDgM+BpwKrAPekmTdJMcgSRqZ9BnA8cDuqrqzqr4HbAM2THgMkiQmHwBrgHsXrO9pbZKkCZv046CzSFs9rkOyBdjSVv8nye0HeayjgG8e5L6zypon5DWPLb1p0oeGPr/P0Fnd+TBw8DX/9JBOkw6APcAxC9aPBu5b2KGqtgJbD/VASXZV1fpDfZ9ZYs196LFm6LPula550lNAXwHWJjk2yeHARmD7hMcgSWLCZwBVtS/JrwNXA4cBF1fVLZMcgyRpZOL/S8iqugq4agKHOuRppBlkzX3osWbos+4VrTlVtXQvSdKTjo+CkKROzXQALPVYiSRPS3Jp274zydzkRzl+A+r+rSS3Jrk5ybVJBt0S9uNs6CNEkpyZpJLM/N0iQ2pO8ovte31Lkr+a9BjHbcDP9k8luT7JTe3n+7RpjHOcklyc5IEkX9/P9iT5aPs3uTnJcWM7eFXN5Beji8j/DrwYOBz4F2DdE/r8GvCJtrwRuHTa455Q3b8APKMtv3PW6x5Sc+v3bOCLwA5g/bTHPYHv81rgJuDItv6CaY97AjVvBd7ZltcBd0973GOo+3XAccDX97P9NODvGH2O6gRg57iOPctnAEMeK7EBuKQtXw6cmGSxD6PNkiXrrqrrq+qRtrqD0ectZtnQR4h8EPgD4DuTHNwKGVLzrwIfq6oHAarqgQmPcdyG1FzAc9ryc3nC54hmUVV9Edh7gC4bgE/XyA7giCQvGsexZzkAhjxW4rE+VbUPeBh4/kRGt3KW+ziNzYz+ephlS9ac5FXAMVX1hUkObAUN+T6/FHhpkn9OsiPJKRMb3coYUvP7gbcm2cPobsJ3TWZoU7Vij9CZ+G2gY7TkYyUG9pk1g2tK8lZgPfDzKzqilXfAmpM8BbgAePukBjQBQ77PqxhNA72e0VnePyZ5ZVU9tMJjWylDan4L8JdV9cdJXgN8ptX8w5Uf3tSs2O+xWT4DWPKxEgv7JFnF6JTxQKdas2BI3SR5I/D7wBlV9d0JjW2lLFXzs4FXAjckuZvRPOn2Gb8QPPTn+4qq+n5V3QXczigQZtWQmjcDlwFU1ZeApzN6Xs6T2aD/5g/GLAfAkMdKbAc2teUzgeuqXVWZYUvW3aZD/pzRL/9ZnxeGJWquqoer6qiqmquqOUbXPc6oql3TGe5YDPn5/ltGF/xJchSjKaE7JzrK8RpS8z3AiQBJXs4oAOYnOsrJ2w68rd0NdALwcFXdP443ntkpoNrPYyWSfADYVVXbgYsYnSLuZvSX/8bpjXg8Btb9h8CzgL9u17zvqaozpjboQzSw5ieVgTVfDZyU5FbgB8BvV9W3pjfqQzOw5vcCn0zyHkbTIG+f9T/qknyO0TTeUe3axrnAUwGq6hOMrnWcBuwGHgHeMbZjz/i/nSTpIM3yFJAk6RAYAJLUKQNAkjplAEhSpwwASeqUASBJnTIAJKlTBoAkder/AKEoHz8NBVt7AAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7f71e66397b8>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Distribution of predicted scores, with the 0.5 decision threshold marked in orange.\n",
    "plt.hist(y_hat)\n",
    "# A trailing semicolon hides the Line2D repr without assigning to `_`,\n",
    "# which would shadow IPython's last-output variable.\n",
    "plt.axvline(x=0.5, color='orange');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'94.87'"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# ROC AUC on the validation split, shown as a percentage with two decimals\n",
    "format(roc_auc_score(y_valid, y_hat) * 100.0, '0.2f')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
